# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//jaxlib:symlink_files.bzl", "symlink_files", "symlink_inputs")

# Every target declared in this package is visible to all other packages
# unless it sets its own `visibility`.
package(default_visibility = ["//visibility:public"])

# py_library "core": its `srcs` are the upstream DialectCorePyFiles, listed
# under the "dialects" key of symlinked_inputs (presumably the destination
# subdirectory; see //jaxlib:symlink_files.bzl).
symlink_inputs(
    name = "core",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:DialectCorePyFiles",
            ],
        },
    },
)

# py_library "extras": its `srcs` are the upstream ExtrasPyFiles, listed under
# the "extras" key of symlinked_inputs. Depends on :ir and :mlir.
symlink_inputs(
    name = "extras",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "extras": [
                "@llvm-project//mlir/python:ExtrasPyFiles",
            ],
        },
    },
    deps = [
        ":ir",
        ":mlir",
    ],
)

# py_library "ir": its `srcs` are the upstream IRPyFiles, listed under the "."
# key of symlinked_inputs (presumably this package's own directory; confirm
# in //jaxlib:symlink_files.bzl).
symlink_inputs(
    name = "ir",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            ".": [
                "@llvm-project//mlir/python:IRPyFiles",
            ],
        },
    },
    deps = [":mlir"],
)

# Source-less py_library that only forwards to //jaxlib/mlir/_mlir_libs.
# Most other targets in this file list ":mlir" in their deps.
py_library(
    name = "mlir",
    deps = ["//jaxlib/mlir/_mlir_libs"],
)

# py_library "func_dialect": its `srcs` are the upstream FuncPyFiles, listed
# under the "dialects" key of symlinked_inputs.
symlink_inputs(
    name = "func_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:FuncPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

# py_library "vector_dialect": its `srcs` are the upstream VectorOpsPyFiles,
# listed under the "dialects" key of symlinked_inputs.
symlink_inputs(
    name = "vector_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:VectorOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

# py_library "math_dialect": its `srcs` are the upstream MathOpsPyFiles,
# listed under the "dialects" key of symlinked_inputs.
symlink_inputs(
    name = "math_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:MathOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

# py_library "arithmetic_dialect": its `srcs` are the upstream ArithOpsPyFiles,
# listed under the "dialects" key of symlinked_inputs. Note that the target
# name says "arithmetic" while the upstream file group is named "Arith".
symlink_inputs(
    name = "arithmetic_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:ArithOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

# py_library "memref_dialect": its `srcs` are the upstream MemRefOpsPyFiles,
# listed under the "dialects" key of symlinked_inputs.
symlink_inputs(
    name = "memref_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:MemRefOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

# py_library "scf_dialect": its `srcs` are the upstream SCFPyFiles, listed
# under the "dialects" key of symlinked_inputs.
symlink_inputs(
    name = "scf_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:SCFPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

# py_library "builtin_dialect": its `srcs` are the upstream BuiltinOpsPyFiles,
# listed under the "dialects" key of symlinked_inputs. Unlike the other
# dialect targets in this file, it also depends on :extras.
symlink_inputs(
    name = "builtin_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:BuiltinOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":extras",
        ":ir",
        ":mlir",
    ],
)

# py_library "chlo_dialect": its `srcs` are @stablehlo//:chlo_ops_py_files,
# listed under the "dialects" key of symlinked_inputs. Also depends on the
# _chlo target in //jaxlib/mlir/_mlir_libs.
symlink_inputs(
    name = "chlo_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@stablehlo//:chlo_ops_py_files",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_chlo",
    ],
)

# py_library "sparse_tensor_dialect": its `srcs` are the upstream
# SparseTensorOpsPyFiles, listed under the "dialects" key of symlinked_inputs.
# Also depends on the sparse-tensor dialect and passes targets in
# //jaxlib/mlir/_mlir_libs.
symlink_inputs(
    name = "sparse_tensor_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:SparseTensorOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_mlirDialectsSparseTensor",
        "//jaxlib/mlir/_mlir_libs:_mlirSparseTensorPasses",
    ],
)

# py_library "mhlo_dialect": its `srcs` are @xla//xla/mlir_hlo:MhloOpsPyFiles,
# listed under the "dialects" key of symlinked_inputs. Also depends on the
# _mlirHlo target in //jaxlib/mlir/_mlir_libs.
symlink_inputs(
    name = "mhlo_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@xla//xla/mlir_hlo:MhloOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_mlirHlo",
    ],
)

# py_library "pass_manager": its `srcs` are the upstream PassManagerPyFiles,
# listed under the "." key of symlinked_inputs (presumably this package's own
# directory; confirm in //jaxlib:symlink_files.bzl).
symlink_inputs(
    name = "pass_manager",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            ".": [
                "@llvm-project//mlir/python:PassManagerPyFiles",
            ],
        },
    },
    deps = [":mlir"],
)

# py_library "stablehlo_dialect": its `srcs` are
# @stablehlo//:stablehlo_ops_py_files, listed under the "dialects" key of
# symlinked_inputs. Also depends on the _stablehlo target in
# //jaxlib/mlir/_mlir_libs.
symlink_inputs(
    name = "stablehlo_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@stablehlo//:stablehlo_ops_py_files",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_stablehlo",
    ],
)

# py_library "execution_engine": its `srcs` are the upstream
# ExecutionEnginePyFiles, listed under the "." key of symlinked_inputs. Also
# depends on the _mlirExecutionEngine target in //jaxlib/mlir/_mlir_libs.
symlink_inputs(
    name = "execution_engine",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            ".": ["@llvm-project//mlir/python:ExecutionEnginePyFiles"],
        },
    },
    deps = [
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_mlirExecutionEngine",
    ],
)

# py_library "nvgpu_dialect": its `srcs` are the upstream NVGPUOpsPyFiles,
# listed under the "dialects" key of symlinked_inputs. Also depends on the
# _mlirDialectsNvgpu target in //jaxlib/mlir/_mlir_libs.
symlink_inputs(
    name = "nvgpu_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:NVGPUOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_mlirDialectsNvgpu",
    ],
)

# py_library "nvvm_dialect": its `srcs` are the upstream NVVMOpsPyFiles,
# listed under the "dialects" key of symlinked_inputs.
symlink_inputs(
    name = "nvvm_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:NVVMOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

# Symlinks the upstream GPUOpsPyFiles into `dialects` with flatten = True.
# Consumed as one of the `srcs` of :gpu_dialect.
symlink_files(
    name = "gpu_files",
    srcs = [
        "@llvm-project//mlir/python:GPUOpsPyFiles",
    ],
    dst = "dialects",
    flatten = True,
)

# Symlinks the upstream GPUOpsPackagePyFiles into `dialects/gpu` with
# flatten = True. Consumed as one of the `srcs` of :gpu_dialect.
symlink_files(
    name = "gpu_package_files",
    srcs = [
        "@llvm-project//mlir/python:GPUOpsPackagePyFiles",
    ],
    dst = "dialects/gpu",
    flatten = True,
)

# Symlinks the upstream GPUOpsPackagePassesPyFiles into `dialects/gpu/passes`
# with flatten = True. Consumed as one of the `srcs` of :gpu_dialect.
symlink_files(
    name = "gpu_package_passes_files",
    srcs = [
        "@llvm-project//mlir/python:GPUOpsPackagePassesPyFiles",
    ],
    dst = "dialects/gpu/passes",
    flatten = True,
)

# py_library "gpu_dialect": declared directly rather than via symlink_inputs.
# Its sources are the three symlink_files targets :gpu_files,
# :gpu_package_files and :gpu_package_passes_files.
py_library(
    name = "gpu_dialect",
    srcs = [
        # Symlinked into dialects/
        ":gpu_files",
        # Symlinked into dialects/gpu/
        ":gpu_package_files",
        # Symlinked into dialects/gpu/passes/
        ":gpu_package_passes_files",
    ],
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_mlirDialectsGPU",
        "//jaxlib/mlir/_mlir_libs:_mlirGPUPasses",
    ],
)

# py_library "llvm_dialect": its `srcs` are the upstream LLVMOpsPyFiles,
# listed under the "dialects" key of symlinked_inputs. Also depends on the
# _mlirDialectsLLVM target in //jaxlib/mlir/_mlir_libs.
symlink_inputs(
    name = "llvm_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:LLVMOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM",
    ],
)

